import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from matplotlib import collections as matcoll
import os
import math

# This file implements a plotting style using Matplotlib and seaborn,
# and actually makes plots of Local SGD variants' loss versus communication rounds

###################################################################################################
# Tweaking seaborn to make our curves more beautiful :)
# Seaborn allows us to change matplotlib parameters through it
# Inspired by: https://towardsdatascience.com/making-matplotlib-beautiful-by-default-d0d41e3534fd

# Matplotlib rc overrides applied through seaborn: no grid, no top/right
# spines, no tick marks, light-grey axes edges and dim-grey labels/ticks.
_style_rc = {
    "axes.axisbelow": False,
    "axes.edgecolor": "lightgrey",
    "axes.facecolor": "None",
    "axes.grid": False,
    "axes.labelcolor": "dimgrey",
    "axes.spines.right": False,
    "axes.spines.top": False,
    "figure.facecolor": "white",
    "lines.solid_capstyle": "round",
    "patch.edgecolor": "w",
    "patch.force_edgecolor": True,
    "text.color": "black",
    "xtick.bottom": False,
    "xtick.color": "dimgrey",
    "xtick.direction": "out",
    "xtick.top": False,
    "ytick.color": "dimgrey",
    "ytick.direction": "out",
    "ytick.left": False,
    "ytick.right": False,
}
sns.set(font="Franklin Gothic Book", rc=_style_rc)

# Global font sizes, layered on top of seaborn's "notebook" context.
_context_rc = {
    "font.size": 16,
    "axes.titlesize": 18,
    "axes.labelsize": 18,
}
sns.set_context("notebook", rc=_context_rc)

# Named colours of the palette (hex RGB strings).
CB91_Blue = "#2CBDFE"
CB91_Green = "#47DBCD"
CB91_Pink = "#F3A0F2"
CB91_Purple = "#9D2EC5"
CB91_Violet = "#661D98"
CB91_Amber = "#F5B14C"
CB91_Black = "#000000"

# Order in which successive curves pick their colour.
# CB91_Violet is defined above but is not part of the cycle.
color_list = [
    CB91_Blue,
    CB91_Black,
    CB91_Green,
    CB91_Purple,
    CB91_Amber,
    CB91_Pink,
]

# Install the colour cycle and the marker edge width as matplotlib defaults.
plt.rcParams.update({
    "axes.prop_cycle": plt.cycler(color=color_list),
    "lines.markeredgewidth": 1,
})

###########################################################################################

# Values substituted into the `homogeneity=` component of the result paths;
# one entry per position in each per-method result list below.
homogeneities = [0.1, 0.35, 0.60, 0.85]

# Values substituted for K in the result paths. They are kept as strings
# because they double as the keys of the per-method result dictionaries.
_K_VALUES = ["16", "32"]

_ORIGINAL_RESULTS = "../original/bvr_l_sgd_20220407/results/20210413/fc/cifar10/"

# Result-directory template for each method.
#   {q}  -> homogeneity value
#   {K}  -> K as given in _K_VALUES
#   {mb} -> 16 * K, used by the K=1 variants
# The insertion order is the order in which directories are visited.
_RUN_DIR_TEMPLATES = {
    "lstorm": "../results/03/fc/cifar10/homogeneity={q}/lstorm_K={K}_b=16/",
    "lsarah": "../results/01/fc/cifar10/homogeneity={q}/lsarah_K={K}_b=16/",
    "mbsarah": _ORIGINAL_RESULTS + "homogeneity={q}/lsarah_K=1_b={mb}/",
    "mbsgd": _ORIGINAL_RESULTS + "homogeneity={q}/lsgd_K=1_b={mb}/",
    "fedavg": _ORIGINAL_RESULTS + "homogeneity={q}/lsgd_K={K}_b=16/",
    "scaffold": _ORIGINAL_RESULTS + "homogeneity={q}/scaffold_K={K}_b=16/",
}


def _best_train_loss(run_dir):
    """Return the lowest 'train_loss' value found among the runs in *run_dir*.

    Every entry of *run_dir* is treated as a sub-directory holding a
    ``seed=0.pickle`` file whose ``'train_loss'`` item is a sequence of
    losses. Empty histories are skipped.

    Returns:
        The minimum loss over all non-empty histories, or ``None`` when no
        history provides a value below infinity (for example when the
        directory is empty or every history is empty).

    Raises:
        FileNotFoundError: if *run_dir* or one of the ``seed=0.pickle``
            files does not exist.
    """
    best = None
    for run_name in os.listdir(run_dir):
        # NOTE: allow_pickle=True unpickles the file, which can execute
        # arbitrary code -- only point this at trusted result files.
        history = np.load(
            os.path.join(run_dir, run_name, "seed=0.pickle"),
            allow_pickle=True,
        )['train_loss']
        if len(history) == 0:
            continue
        lowest = np.min(history)
        # A NaN or infinite minimum never compares as smaller and is ignored.
        if lowest < (np.inf if best is None else best):
            best = lowest
    return best


# Best train loss per method: {K: [one value per entry of `homogeneities`]}.
lstorm = {K: [None] * len(homogeneities) for K in _K_VALUES}
lsarah = {K: [None] * len(homogeneities) for K in _K_VALUES}
mbsarah = {K: [None] * len(homogeneities) for K in _K_VALUES}
mbsgd = {K: [None] * len(homogeneities) for K in _K_VALUES}
fedavg = {K: [None] * len(homogeneities) for K in _K_VALUES}
scaffold = {K: [None] * len(homogeneities) for K in _K_VALUES}

_results_by_method = {
    "lstorm": lstorm,
    "lsarah": lsarah,
    "mbsarah": mbsarah,
    "mbsgd": mbsgd,
    "fedavg": fedavg,
    "scaffold": scaffold,
}

for K in _K_VALUES:
    for i, homogeneity in enumerate(homogeneities):
        for method, template in _RUN_DIR_TEMPLATES.items():
            run_dir = template.format(q=homogeneity, K=K, mb=16 * int(K))
            _results_by_method[method][K][i] = _best_train_loss(run_dir)
    

# Single-panel figure: log10 of the best train loss of each method for K=32,
# plotted against the entries of `homogeneities`. The K=16 results are loaded
# above but are not drawn here.
fig, axes = plt.subplots(1, 1, figsize=(8, 6))

# Data points sit at x = 1, 2, ...; the tick at x = 0 carries the label '0'
# and has no data point.
x = np.arange(len(homogeneities)) + 1

labels = homogeneities  # not used by the plot; kept as a module-level name

axes.set_xticks([0, 1, 2, 3, 4])
axes.set_xticklabels(['0', '0.1', '0.35', '0.60', '0.85'])

# (legend label, per-method results) in drawing order; the order also decides
# which colour of the default cycle each curve receives.
_curves = [
    ("BVR-LSGD", lsarah),
    ("CE-LSGD", lstorm),
    ("MB-SARAH", mbsarah),
    ("MB-SGD", mbsgd),
    ("FedAVG", fedavg),
    ("Scaffold", scaffold),
]
for _label, _results in _curves:
    axes.plot(x, np.log10(np.array(_results["32"])), label=_label,
              marker="o", linestyle='dashed', linewidth=3, markersize=12)

# Raw string: "\l" is not a valid escape sequence in a normal string literal.
axes.set(title="$K=32$", xlabel="q", ylabel=r"$\log_{10}$(Train Loss)")
axes.legend(ncol=3, bbox_to_anchor=(0.85, -0.15))

plt.tight_layout()
# savefig does not create missing directories, so make sure "figs" exists.
os.makedirs("figs", exist_ok=True)
plt.savefig("figs/hetero.png", dpi=150)
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    